library(sf)
library(raster)
library(terra)
library(ggplot2)
library(viridis)
library(scales)
library(cowplot)
library(osmdata)
library(tidyverse)

# ============================================================================
# FOUR PANEL VISUALIZATION: Mexico City Zoom (Horizontal Layout)
# ============================================================================
# Panel A: Original 1km population grid
# Panel B: Aggregated 10km grid with population
# Panel C: Centroids with 7070m search radius
# Panel D: Road network context

# ============================================================================
# Load and prepare data
# ============================================================================
cat("Loading data...\n")

# Inputs: state boundary polygons and the 1km WorldPop raster
mexico_states <- st_read(dsn = "mx-states.shp", quiet = TRUE)
r_original <- raster(x = "mx_pd_2020_1km_Aggregated.tif")

# Study window around Mexico City (roughly 19.4N, 99.1W), declared in
# EPSG:4326. It is kept twice: as an sf bbox (used for vector cropping and
# plot limits) and as a raster extent (used for raster cropping).
mexico_city_bbox <- st_bbox(
  c(
    xmin = -99.3,
    xmax = -98.9,
    ymin = 19.2,
    ymax = 19.6
  ),
  crs = 4326
)
mexico_city_extent <- extent(
  mexico_city_bbox[c("xmin", "xmax", "ymin", "ymax")]
)

# Restrict the raster to the study window
cat("Cropping to Mexico City area...\n")
r_city_1km <- crop(x = r_original, y = mexico_city_extent)

# Coarsen by a factor of 10 in each direction, summing the merged cells
cat("Aggregating to 10km...\n")
r_city_10km <- r_city_1km %>%
  aggregate(fact = 10, fun = sum)

# ============================================================================
# PANEL A: Original 1km grid
# ============================================================================
cat("Creating Panel A: 1km resolution...\n")

# Flatten the cropped 1km raster into a data frame with one row per point
# returned by rasterToPoints(), then keep only strictly positive cells
# (the fill scale below is log10-transformed).
pop_1km_df <- r_city_1km %>%
  rasterToPoints() %>%
  as.data.frame() %>%
  stats::setNames(c("x", "y", "population")) %>%
  filter(population > 0)

# Shared look for this panel: minimal theme, slim legend key, no axis titles
theme_panel_a <- theme_minimal() +
  theme(
    plot.title = element_text(face = "bold", size = 10),
    axis.title = element_blank(),
    axis.text = element_text(size = 7),
    panel.grid = element_line(colour = "grey90", linewidth = 0.2),
    legend.position = "right",
    legend.key.width = unit(0.3, "cm"),
    legend.key.height = unit(0.8, "cm")
  )

# Layers are drawn in order: population tiles first, state outlines on top
panel_a <- ggplot() +
  geom_tile(
    mapping = aes(x = x, y = y, fill = population),
    data = pop_1km_df
  ) +
  geom_sf(
    data = st_crop(mexico_states, mexico_city_bbox),
    fill = NA,
    colour = "white",
    linewidth = 0.5,
    alpha = 0.7
  ) +
  scale_fill_viridis(
    name = "Pop.",
    option = "magma",
    trans = "log10",
    labels = comma
  ) +
  coord_sf(
    xlim = c(mexico_city_bbox["xmin"], mexico_city_bbox["xmax"]),
    ylim = c(mexico_city_bbox["ymin"], mexico_city_bbox["ymax"])
  ) +
  ggtitle("A. Original 1km Grid") +
  theme_panel_a

# ============================================================================
# PANEL B: Aggregated 10km grid with borders
# ============================================================================
cat("Creating Panel B: 10km aggregated grid...\n")

# One row per point of the aggregated raster; keep strictly positive cells and
# flag those above the 100-inhabitant selection threshold.
pop_10km_df <- r_city_10km %>%
  rasterToPoints() %>%
  as.data.frame() %>%
  stats::setNames(c("x", "y", "population")) %>%
  filter(population > 0) %>%
  mutate(selected = population > 100)

# Cell size used to draw the grid polygons: the first component of the
# aggregated raster's resolution
raster_res <- res(r_city_10km)[[1]]

# Build the rectangular outline of one grid cell as an sf polygon.
#
# Arguments:
#   x, y  Coordinates of the cell centre (single numbers).
#   res   Cell size in the same units as x and y. Either one number (square
#         cell) or a length-2 vector c(x_size, y_size) for rectangular cells.
#
# Returns:
#   The result of st_polygon() on a closed 5-vertex ring that starts at the
#   lower-left corner and visits lower-right, upper-right and upper-left
#   before returning to the start.
create_cell_polygon <- function(x, y, res) {
  stopifnot(
    is.numeric(x), length(x) == 1L,
    is.numeric(y), length(y) == 1L,
    is.numeric(res), length(res) %in% c(1L, 2L)
  )

  # A scalar resolution is recycled so width and height are equal
  half_res <- rep_len(res, 2L) / 2

  x_lo <- x - half_res[1]
  x_hi <- x + half_res[1]
  y_lo <- y - half_res[2]
  y_hi <- y + half_res[2]

  # Column-major fill: first column is x, second is y; the first vertex is
  # repeated at the end to close the ring
  coords <- matrix(
    c(
      x_lo, x_hi, x_hi, x_lo, x_lo,
      y_lo, y_lo, y_hi, y_hi, y_lo
    ),
    ncol = 2
  )
  st_polygon(list(coords))
}

# Build one polygon per 10km cell, centred on the cell coordinates, and collect
# them into an explicit sfc geometry column with its CRS attached. This avoids
# a rowwise() pass and does not rely on st_as_sf() coercing a bare list
# column; it also yields a valid (empty) geometry column when there are no rows.
cell_geometries <- st_sfc(
  map2(
    pop_10km_df$x,
    pop_10km_df$y,
    function(cx, cy) create_cell_polygon(cx, cy, raster_res)
  ),
  crs = 4326
)

# Attribute table (x, y, population, selected) plus the cell polygons
grid_cells_sf <- pop_10km_df %>%
  as_tibble() %>%
  mutate(geometry = cell_geometries) %>%
  st_as_sf()

# Panel B: aggregated cells filled by population, with cell borders drawn so
# the grid itself is visible, and state outlines on top
panel_b <- ggplot() +
  geom_sf(data = grid_cells_sf,
          aes(fill = population),
          colour = "grey30",
          linewidth = 0.5) +
  scale_fill_viridis(option = "viridis",
                     trans = "log10",
                     name = "Pop.",
                     labels = comma) +
  geom_sf(data = st_crop(mexico_states, mexico_city_bbox),
          fill = NA, colour = "white", linewidth = 0.7, alpha = 0.8) +
  # Clamp the view to the study window
  coord_sf(xlim = c(mexico_city_bbox["xmin"], mexico_city_bbox["xmax"]),
           ylim = c(mexico_city_bbox["ymin"], mexico_city_bbox["ymax"])) +
  labs(title = "B. Aggregated 10km Grid") +
  theme_minimal() +
  theme(legend.position = "right",
        legend.key.width = unit(0.3, "cm"),
        legend.key.height = unit(0.8, "cm"),
        plot.title = element_text(face = "bold", size = 10),
        axis.title = element_blank(),
        axis.text = element_text(size = 7),
        panel.grid = element_line(colour = "grey90", linewidth = 0.2))

# ============================================================================
# PANEL C: Centroids with 7070m search radius
# ============================================================================
cat("Creating Panel C: Centroids with search radius...\n")

# Buffer parameters: radius of each search circle and the projected CRS in
# which the buffering is carried out
search_radius <- 7070
mexico_crs <- "EPSG:32614"  # UTM Zone 14N for Mexico City

# Selected cells become point features located at their x/y coordinates
selected_cells <- filter(pop_10km_df, selected)
centroids_sf <- st_as_sf(
  selected_cells,
  coords = c("x", "y"),
  crs = 4326
)

# Buffer in UTM, then bring the circles back to EPSG:4326 for plotting
centroids_utm <- st_transform(centroids_sf, crs = mexico_crs)
search_circles <- centroids_utm %>%
  st_buffer(dist = search_radius) %>%
  st_transform(crs = 4326)

# Polygons of the selected cells, drawn as the panel background
selected_grid_sf <- filter(grid_cells_sf, selected)

# Shared look for this panel
theme_panel_c <- theme_minimal() +
  theme(
    plot.title = element_text(face = "bold", size = 10),
    axis.title = element_blank(),
    axis.text = element_text(size = 7),
    panel.grid = element_line(colour = "grey90", linewidth = 0.2),
    legend.position = "right",
    legend.key.width = unit(0.3, "cm"),
    legend.key.height = unit(0.8, "cm")
  )

# Drawing order (bottom to top): selected cells, search circles, centroids,
# state outlines
panel_c <- ggplot() +
  geom_sf(
    data = selected_grid_sf,
    mapping = aes(fill = population),
    colour = "grey50",
    linewidth = 0.3,
    alpha = 0.5
  ) +
  geom_sf(
    data = search_circles,
    fill = alpha("red", 0.15),
    colour = "red",
    linewidth = 0.6
  ) +
  geom_sf(
    data = centroids_sf,
    colour = "darkred",
    size = 1.5,
    shape = 19
  ) +
  geom_sf(
    data = st_crop(mexico_states, mexico_city_bbox),
    fill = NA,
    colour = "grey30",
    linewidth = 0.7,
    alpha = 0.5
  ) +
  scale_fill_viridis(
    name = "Pop.",
    option = "viridis",
    trans = "log10",
    labels = comma
  ) +
  coord_sf(
    xlim = c(mexico_city_bbox["xmin"], mexico_city_bbox["xmax"]),
    ylim = c(mexico_city_bbox["ymin"], mexico_city_bbox["ymax"])
  ) +
  ggtitle("C. Search Radius (7070m)") +
  theme_panel_c

# ============================================================================
# PANEL D: Road network context
# ============================================================================
cat("Creating Panel D: Road network context...\n")
cat("  Fetching OpenStreetMap data (this may take a moment)...\n")

# State outlines clipped to the study window; used by both the road panel and
# the fallback panel
states_in_view <- st_crop(mexico_states, mexico_city_bbox)

# Legend configuration for the requested road classes. Colours and labels are
# keyed by the raw `highway` value and explicit breaks are supplied, so every
# label is matched to its class by name. Positional (unnamed) labels would be
# applied to the alphabetically sorted data values and end up on the wrong
# classes, or fail outright when a class is absent from the data.
road_classes <- c("motorway", "trunk", "primary", "secondary")
road_colours <- c("motorway" = "#e41a1c",
                  "trunk" = "#ff7f00",
                  "primary" = "#ffff33",
                  "secondary" = "#a6cee3")
road_labels <- c("motorway" = "Motorway",
                 "trunk" = "Trunk",
                 "primary" = "Primary",
                 "secondary" = "Secondary")

# Build panel D from OSM roads, falling back to a road-free panel on failure.
# The *value* of tryCatch() is assigned to panel_d: an assignment made inside
# the error handler would be local to the handler function and leave panel_d
# undefined at top level after a failed download.
panel_d <- tryCatch({
  # Query OSM for major roads inside the study window
  roads_query <- opq(bbox = c(mexico_city_bbox["xmin"],
                              mexico_city_bbox["ymin"],
                              mexico_city_bbox["xmax"],
                              mexico_city_bbox["ymax"])) %>%
    add_osm_feature(key = "highway",
                    value = road_classes) %>%
    osmdata_sf()

  roads_sf <- roads_query$osm_lines

  if (!is.null(roads_sf) && nrow(roads_sf) > 0) {
    cat(sprintf("  Retrieved %d road segments\n", nrow(roads_sf)))

    ggplot() +
      # Background: grid cells
      geom_sf(data = selected_grid_sf,
              fill = "grey95",
              colour = "grey70",
              linewidth = 0.3) +
      # Roads, coloured by OSM highway class
      geom_sf(data = roads_sf,
              aes(colour = highway),
              linewidth = 0.5,
              alpha = 0.7) +
      scale_colour_manual(
        name = "Road Type",
        values = road_colours,
        breaks = road_classes,
        labels = road_labels
      ) +
      # Centroids
      geom_sf(data = centroids_sf,
              colour = "darkblue",
              size = 1.5,
              shape = 19) +
      # State boundaries
      geom_sf(data = states_in_view,
              fill = NA, colour = "black", linewidth = 0.7) +
      coord_sf(xlim = c(mexico_city_bbox["xmin"], mexico_city_bbox["xmax"]),
               ylim = c(mexico_city_bbox["ymin"], mexico_city_bbox["ymax"])) +
      labs(title = "D. Road Network Context") +
      theme_minimal() +
      theme(legend.position = "right",
            legend.key.width = unit(0.3, "cm"),
            legend.key.height = unit(0.5, "cm"),
            legend.text = element_text(size = 7),
            legend.title = element_text(size = 8, face = "bold"),
            plot.title = element_text(face = "bold", size = 10),
            axis.title = element_blank(),
            axis.text = element_text(size = 7),
            panel.grid = element_line(colour = "grey90", linewidth = 0.2))
  } else {
    # The query succeeded but returned no line features: show a placeholder
    cat("  No roads found, creating alternative panel...\n")
    ggplot() +
      annotate("text", x = 0.5, y = 0.5,
               label = "Road data unavailable", size = 5) +
      theme_void()
  }
}, error = function(e) {
  # Any error raised while evaluating the block above lands here; report it
  # and return a panel showing only the sampling points
  cat("  Error fetching OSM data, creating alternative panel...\n")
  cat(sprintf("  Error: %s\n", conditionMessage(e)))

  ggplot() +
    geom_sf(data = selected_grid_sf,
            aes(fill = population),
            colour = "grey50",
            linewidth = 0.3) +
    scale_fill_viridis(option = "viridis",
                       trans = "log10",
                       name = "Pop.",
                       labels = comma) +
    geom_sf(data = centroids_sf,
            colour = "darkblue",
            size = 1.5,
            shape = 19) +
    geom_sf(data = states_in_view,
            fill = NA, colour = "black", linewidth = 0.7) +
    coord_sf(xlim = c(mexico_city_bbox["xmin"], mexico_city_bbox["xmax"]),
             ylim = c(mexico_city_bbox["ymin"], mexico_city_bbox["ymax"])) +
    labs(title = "D. Sampling Points",
         subtitle = "Blue points: query locations") +
    theme_minimal() +
    theme(legend.position = "right",
          legend.key.width = unit(0.3, "cm"),
          legend.key.height = unit(0.8, "cm"),
          plot.title = element_text(face = "bold", size = 10),
          plot.subtitle = element_text(size = 8, colour = "grey40"),
          axis.title = element_blank(),
          axis.text = element_text(size = 7),
          panel.grid = element_line(colour = "grey90", linewidth = 0.2))
})

# ============================================================================
# Combine panels HORIZONTALLY (single row, 1x4 grid)
# ============================================================================
cat("Combining panels horizontally...\n")

# Lay the four panels out side by side in one row
combined_plot <- plot_grid(
  plotlist = list(panel_a, panel_b, panel_c, panel_d),
  nrow = 1,
  ncol = 4,
  align = "h"
)

# Counts reused in the subtitle and in the closing console summary
n_cells <- nrow(pop_10km_df)
n_selected <- sum(pop_10km_df$selected)
selected_pop_label <- format(
  round(sum(selected_cells$population)),
  big.mark = ","
)

# Figure heading: fixed title plus a one-line summary of the sampling frame
title_text <- "Geographic Sampling Framework: Mexico City Detail"
subtitle_text <- sprintf(
  "%d grid cells | %d selected (>100 inhabitants) | %d search circles | Population: %s",
  n_cells,
  n_selected,
  nrow(search_circles),
  selected_pop_label
)

# The panel row fills the lower 92% of the canvas; title and subtitle are
# drawn in the strip above it
final_plot <- ggdraw() +
  draw_plot(combined_plot, y = 0, height = 0.92) +
  draw_label(
    title_text,
    x = 0.5, y = 0.98,
    hjust = 0.5, vjust = 1,
    fontface = "bold", size = 14
  ) +
  draw_label(
    subtitle_text,
    x = 0.5, y = 0.95,
    hjust = 0.5, vjust = 2,
    colour = "grey40", size = 9
  )

print(final_plot)

# Write the figure as both PNG and PDF with identical dimensions
output_files <- c("mexico_city_sampling_4panel.png",
                  "mexico_city_sampling_4panel.pdf")
for (output_file in output_files) {
  ggsave(output_file, final_plot,
         width = 15, height = 4, dpi = 300, bg = "white")
}

cat("\nFour-panel horizontal visualization complete!\n")
cat(sprintf("Mexico City area: %d cells, %d selected (%.1f%%)\n",
            n_cells,
            n_selected,
            100 * n_selected / n_cells))